#!/usr/bin/env python
import math
import os
import time

import numpy as np
import torch

from jieba_tokenizer import MyTokenizer, text

# GPT model and its configuration class
from model import GPTConfig, GPT

# --- learning-rate schedule (read by get_lr and the training loop) ---
decay_lr = True        # use the warmup + cosine schedule instead of a constant rate
warmup_iters = 1       # number of linear-warmup iterations
lr_decay_iters = 2     # iteration after which the rate stays at min_lr
min_lr = 6e-5          # learning-rate floor

# --- training-loop bookkeeping ---
iter_num = 0           # global iteration counter, advanced by the training loop
eval_interval = 2      # evaluate (and possibly checkpoint) every this many iterations
master_process = True  # gates evaluation, checkpointing and logging


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    """Return the learning rate for iteration ``it`` (linear warmup, then cosine decay).

    Reads the module-level ``learning_rate`` (peak rate), ``min_lr`` (floor),
    ``warmup_iters`` and ``lr_decay_iters``.

    Args:
        it: Training iteration number.

    Returns:
        ``learning_rate * it / warmup_iters`` while ``it < warmup_iters``;
        ``min_lr`` once ``it >= lr_decay_iters``; otherwise a cosine
        interpolation between ``learning_rate`` and ``min_lr``.
    """
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) at or past lr_decay_iters the schedule has bottomed out.
    #    `>=` (rather than `>`) yields the same value at it == lr_decay_iters
    #    as the cosine formula would (its coefficient is 0 there), and it
    #    avoids a 0/0 division when warmup_iters == lr_decay_iters.
    if it >= lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 at the start of decay, falling towards 0
    return min_lr + coeff * (learning_rate - min_lr)


always_save_checkpoint = False  # when True, checkpoint at every evaluation, not only on improvement

# Prefer Apple's MPS backend (the original target of this script), but fall
# back to CUDA or CPU so the script still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
out_dir = "out"  # directory that receives ckpt.pt
block_size = 12  # tokens per training sequence
batch_size = 3  # sequences per micro-batch
# ---- model architecture ----
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?

# model init: collect the constructor arguments; vocab_size is filled in below
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout)

eval_iters = 10  # batches averaged per split by estimate_loss
max_iters = 20  # total training iterations (tiny run; a commented-out alternative was 600000)
# NOTE(review): vocabulary size is hard-coded; presumably it must cover every
# id MyTokenizer can emit for `text` -- confirm against the tokenizer.
model_args['vocab_size'] = 310
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
# adamw optimizer
learning_rate = 6e-4  # max learning rate
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device)
# No fp16/autocast is used in this script, so gradient scaling is not needed.
# Constructing the scaler as disabled makes every scaler.* call in the training
# loop a pass-through and avoids the "CUDA is not available" warning that an
# enabled scaler emits on non-CUDA devices.
scaler = torch.cuda.amp.GradScaler(enabled=False)

# Build the vocabulary from the corpus, then encode the corpus once.
tokenizer = MyTokenizer()
tokenizer.build_dict(text)
data = tokenizer.encode(text)


def get_batch(split='train'):
    """Sample a random batch of (input, target) token windows from the corpus.

    Uses the module-level ``data`` token ids, which are encoded once at
    start-up, instead of re-tokenizing the whole corpus on every call.

    Args:
        split: Accepted for interface compatibility but currently ignored:
            there is a single corpus, so every split draws from the same tokens.

    Returns:
        Tuple ``(x, y)`` of tensors moved to ``device``, each holding
        ``batch_size`` rows of ``block_size`` token ids. ``y`` is the window
        starting one token after ``x`` (the next-token targets).

    Raises:
        ValueError: If the corpus has no more than ``block_size`` tokens, so
            no (input, target) window pair fits.
    """
    n_tokens = len(data)
    if n_tokens <= block_size:
        raise ValueError(
            f"corpus has {n_tokens} tokens; need more than block_size={block_size} to sample a batch"
        )
    # Random window start offsets; the upper bound leaves room for the shifted target window.
    starts = torch.randint(n_tokens - block_size, (batch_size,)).tolist()
    x = torch.stack([torch.from_numpy(np.array(data[i:i + block_size])) for i in starts])
    y = torch.stack([torch.from_numpy(np.array(data[i + 1:i + 1 + block_size])) for i in starts])
    return x.to(device), y.to(device)


@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss per split over ``eval_iters`` random batches.

    Switches the module-level ``model`` to eval mode for the measurement and
    back to train mode before returning. Runs with gradients disabled.

    Returns:
        Dict mapping ``'train'`` and ``'val'`` to a 0-d tensor holding the
        mean loss over ``eval_iters`` batches.
    """
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # Ask for the split being evaluated (previously hard-coded to
            # 'train'). NOTE: if get_batch ignores its argument, both entries
            # are still measured on the same corpus.
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# --- loop configuration ---
eval_only = False  # when True, the loop exits right after the first evaluation
log_interval = 1  # print a progress line every this many iterations
grad_clip = 1.0  # max gradient norm; 0.0 disables clipping
gradient_accumulation_steps = 40  # micro-batches accumulated per optimizer step (5 * 8)

# "DDP" usually refers to Distributed Data Parallel. No wrapper is applied
# here, so raw_model is simply the model itself.
raw_model = model

# --- loop state ---
local_iter_num = 0  # iterations completed by this process
running_mfu = -1.0  # -1.0 means "no measurement yet"
best_val_loss = 1e9

# Fetch the first batch, then start the iteration timer.
X, Y = get_batch('train')
t0 = time.time()

# Main training loop: set the learning rate, periodically evaluate/checkpoint,
# run one optimizer step built from gradient_accumulation_steps micro-batches,
# log, and stop once iter_num exceeds max_iters.
while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # TODO: fill in the remaining pieces
    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}")
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            # never checkpoint the untrained model at iteration 0
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': None,
                }
                print(f"saving checkpoint to {out_dir}")
                # torch.save does not create missing directories, so make sure
                # the output directory exists before writing into it.
                os.makedirs(out_dir, exist_ok=True)
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        logits, loss = model(X, Y)
        loss = loss / gradient_accumulation_steps  # scale the loss to account for gradient accumulation
        # fetch the next batch before running the backward pass
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:  # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            # exponential moving average of the utilisation estimate
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

init_from = "resume"
if init_from == 'resume':
    # Rebuild the model from the checkpoint written by the training loop.
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names where present
    # (presumably added when the saved model was wrapped by torch.compile -- confirm).
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    # GPT(gptconf) is constructed without a device argument, so move the
    # rebuilt model onto `device` explicitly; the prompt tensor used for
    # generation is created on `device`, and mismatched devices raise an error.
    model.to(device)

# Inference mode for the generation step that follows.
model.eval()


# --- sampling configuration ---
max_new_tokens = 12  # number of tokens requested from model.generate
temperature = 1  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 12  # retain only the top_k most likely tokens, clamp others to have 0 probability

# Encode the prompt and add a leading batch dimension of size 1.
start_ids = tokenizer.encode("女人")
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)

# run generation (a single sample) with gradients disabled
with torch.no_grad():
    for _ in range(1):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        token_ids = y[0].tolist()
        decoded = tokenizer.decode(token_ids)
        print(decoded)

if __name__ == '__main__':
    print("==========over==========")
